/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.subcommands;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.StringTokenizer;

import mosdi.discovery.TextModelBounds;
import mosdi.distributions.PoissonDistribution;
import mosdi.fa.Alphabet;
import mosdi.fa.CDFA;
import mosdi.fa.DFAFactory;
import mosdi.fa.DoublePatternExpectationCalculator;
import mosdi.fa.FiniteMemoryTextModel;
import mosdi.fa.ForwardPatternExpectationCalculator;
import mosdi.fa.IIDTextModel;
import mosdi.fa.MarkovianTextModel;
import mosdi.fa.PatternExpectationCalculator;
import mosdi.paa.ClumpSizeCalculator;
import mosdi.util.BitArray;
import mosdi.util.FileUtils;
import mosdi.util.Iupac;
import mosdi.util.Log;
import mosdi.util.LogSpace;
import mosdi.util.NamedSequence;
import mosdi.util.SequenceUtils;

public class CalcScoresSubcommand extends Subcommand {
	
	@Override
	public String usage() {
		// Assemble the usage text line by line; the resulting string is
		// identical to the former single concatenated literal.
		StringBuilder text = new StringBuilder(super.usage());
		text.append(" [options] <fasta-file> <iupac-pattern>\n");
		text.append("\n");
		text.append("Options:\n");
		text.append("  -i: replace unknown characters in DNA sequences by $\n");
		text.append("  -F: read patterns from file\n");
		text.append("  -r: simultaneously consider reverse complementary motif\n");
		text.append("  -T <truth-file>: Given a file containing the true answers,\n");
		text.append("                   calculates the statistics nTP, nFP, nFN, nTN, nSn, nPPV, nSp, nPC, nCC\n");
		text.append("                   (in that order) as used in Tompa et al., Nat. Biotechnol.(23), 137-144, 2005.\n");
		text.append("                   Format: <seq_nr>,<pos>,<string>\n");
		text.append("                   Where <seq_nr> starts at 0 and <pos>=-1 means the last nucleotide.\n");
		text.append("  -M <order>: use Markovian text model of order <order>. Default: 0 (i.i.d.)\n");
		text.append("  -q <q-gram-table-file>: Estimate text model from given q-gram table\n");
		text.append("                          (default: estimate from sequences). A q-gram table can be\n");
		text.append("                          created by using \"mosdi-utils count-qgrams\".");
		return text.toString();
	}
	
	@Override
	public String description() {
		// One-paragraph summary shown in the subcommand overview; content is
		// unchanged, only assembled differently.
		String summary = "Calculates scores for given motifs. Useful for re-evaluation ";
		summary += "on other sequences or using other text models. Additionally, given the true ";
		summary += "answers, calculates some statistics.";
		return summary;
	}

	@Override
	public String name() {
		// Identifier under which this subcommand is invoked on the command line.
		final String subcommandName = "calc-scores";
		return subcommandName;
	}

	/**
	 * Executes the "calc-scores" subcommand: reads DNA sequences from a FASTA
	 * file, builds or loads a finite-memory text model, and for every given
	 * IUPAC pattern computes match counts, expectations, clump-size statistics
	 * and p-values, printing one result line per pattern (see the "Format:"
	 * header printed below). If a truth file is given (-T), additionally
	 * computes nucleotide-level accuracy statistics against the known sites.
	 *
	 * @param args command-line arguments: options followed by the two mandatory
	 *             arguments fasta-file and iupac-pattern (cf. usage())
	 * @return 0 on success; terminates the JVM with exit code 1 on read errors
	 */
	@Override
	public int run(String[] args) {
		// TODO: Implement for different objectives and use ObjectiveFunction.staticEvaluate()
		// Two mandatory positional arguments; F, r, i are flags, T, M, q take values.
		parseOptions(args, 2, "FrT:M:iq:");

		// Option dependencies:
		// -q (model from q-gram table) and -M (estimate Markov model) exclude each other
		exclusiveOptions("q", "M");

		// Mandatory arguments
		String sequenceFilename = getStringArgument(0);
		String pattern = getStringArgument(1);

		// Options
		Alphabet iupacAlphabet = Alphabet.getIupacAlphabet();
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		List<NamedSequence> namedSequences = null;
		boolean considerReverse = getBooleanOption("r", false);
		int textModelOrder = getNonNegativeIntOption("M", 0);
		boolean replaceUnknownByMinusOne = getBooleanOption("i", false);
		String qGramFilename = getStringOption("q", null);
		boolean readPatternsFromFile = getBooleanOption("F", false);
		String truthFilename = getStringOption("T", null);
		
		// create list of patterns: either all patterns from a file (-F), or the
		// single pattern given on the command line
		List<String> patternList = null;
		if (readPatternsFromFile) {
			patternList = FileUtils.readPatternFile(pattern);
		} else {
			patternList = new ArrayList<String>(1);
			patternList.add(pattern);
		}
		Log.startTimer();
		try {
			namedSequences = SequenceUtils.readFastaFile(sequenceFilename, dnaAlphabet, replaceUnknownByMinusOne);
		} catch (Exception e) {
			Log.errorln("Error reading "+sequenceFilename+":");
			Log.errorln(e.toString());
			System.exit(1);
		}
		Log.stopTimer("Reading FASTA file");
		// strip the names: keep the raw encoded sequences plus a parallel array
		// of their lengths (used for the sequence-count distribution below)
		List<int[]> sequences = new ArrayList<int[]>(namedSequences.size());
		int[] sequenceLengths = new int[namedSequences.size()];
		int n = 0;
		for (NamedSequence ns : namedSequences) {
			sequences.add(ns.getSequence());
			sequenceLengths[n++] = ns.length();
		}
		// choose the text model: from a q-gram table (-q), or estimated from the
		// input sequences as i.i.d. (order 0, default) or Markovian (-M <order>)
		FiniteMemoryTextModel textModel = null;
		if (qGramFilename!=null) {
			try {
				textModel = SequenceUtils.buildTextModelFromQGramFile(qGramFilename);
			} catch (Exception e) {
				Log.errorln("Error reading "+qGramFilename+":");
				Log.errorln(e.toString());
				System.exit(1);
			}
		} else {
			if (textModelOrder==0) {
				textModel = new IIDTextModel(dnaAlphabet.size(), sequences);
			} else {
				textModel = new MarkovianTextModel(textModelOrder, dnaAlphabet.size(), sequences);
			}
		}
		// total sequence length
		// NOTE(review): sequenceLength is computed here but never read again in
		// this method; candidate for removal — verify it is indeed dead code.
		int sequenceLength = 0;
		for (int[] s : sequences) sequenceLength+=s.length;
		
		// read truth file
		// contains a BitArray for each sequence
		BitArray[] truth = null;
		if (truthFilename!=null) truth = CalcScoresSubcommand.readTruthFile(truthFilename, dnaAlphabet, sequences);		
		
		// Precompute text-model bounds for forward and reversed model once; they
		// are reused for every pattern in the loop below.
		TextModelBounds textModelBounds = new TextModelBounds(textModel, Iupac.asGeneralizedAlphabet());
		TextModelBounds backwardTextModelBounds = new TextModelBounds(textModel.reverseTextModel(), Iupac.asGeneralizedAlphabet());
		
		Log.println(Log.Level.STANDARD, "Format: >> pattern >occurrence_count> #matches minus-log-pvalue pvalue expectation >sequence_count> #matching-sequences minus-log-pvalue pvalue expectation >stats> #dfa-states, #dfa-states-minimized #paa-states expectation-per-pos expected-clump-size >model_order> text-model-order");
		
		Log.startTimer();
		for (String forwardPattern : patternList) {
			Log.startTimer();
			Log.startTimer();
			// Build a DFA recognizing the pattern (and, with -r, its reverse
			// complement). NOTE(review): the meaning of the constant 50000 is
			// not visible here (presumably a state limit) — confirm in DFAFactory.
			CDFA cdfa = DFAFactory.buildFromIupacPattern(forwardPattern, considerReverse, 50000);
			int states = cdfa.getStateCount();
			cdfa = cdfa.minimizeHopcroft();
			int statesMinimal = cdfa.getStateCount();

			// count matches
			int totalMatches = 0;
			int matchingSequences = 0;
			for (int i=0; i<namedSequences.size(); ++i) {
				NamedSequence ns = namedSequences.get(i);
				int x = cdfa.countMatches(ns.getSequence(), true);
				// Log.printf(Log.Level.DEBUG, "Found %d matches in sequence \"%s\"\n", x, ns.getName());
				// Log.println(Log.Level.DEBUG, dnaAlphabet.buildString(ns.getSequence()));
				totalMatches += x;
				matchingSequences += (x>0)?1:0;
			}

			Log.startTimer();
			// Per-position occurrence probability of the pattern under the text
			// model; the double variant accounts for the reverse complement too.
			PatternExpectationCalculator pec;
			if (considerReverse) {
				pec = new DoublePatternExpectationCalculator(textModel, iupacAlphabet.buildIndexArray(forwardPattern), textModelBounds, backwardTextModelBounds);
			} else {
				pec = new ForwardPatternExpectationCalculator(textModel, Iupac.asGeneralizedAlphabet(), iupacAlphabet.buildIndexArray(forwardPattern), textModelBounds);
			} 
			// expected total occurrence count = per-position expectation times
			// the number of possible window start positions, summed over sequences
			double singleExpectation = pec.getExpectation();
			double expectation = 0.0;
			for (int[] s : sequences) {
				expectation+=singleExpectation*(s.length-forwardPattern.length()+1);
			}
			Log.stopTimer("Computing expectation");

			// calculate clump size distribution
			ClumpSizeCalculator csc = new ClumpSizeCalculator(textModel, cdfa, forwardPattern.length());
			int statesProduct = csc.getProductStateCount();
			double[] clumpSizeDist = csc.clumpSizeDistribution(30, 1e-300);
			//double[] clumpSizeDist = csc.clumpSizeDistribution(8, 1e-30);
			Log.restartTimer("calculate clump size distribution");
			// double timeClumpSizeDist = Log.getLastPeriodCpu();
			
			// calculate expected clump size
			double expectedClumpSize = 0.0;
			for (int i=1; i<clumpSizeDist.length; ++i) {
				expectedClumpSize+=clumpSizeDist[i]*i;
			}
			
			// Distribution of the number of sequences containing at least one
			// match; used for the >sequence_count> expectation and p-value.
			double[] dist = SequenceUtils.calculateSequenceCountDistribution(sequenceLengths, singleExpectation, forwardPattern.length(), expectedClumpSize, false);
			double expectedSeqCount = 0.0;
			for (int i=0; i<dist.length; ++i) expectedSeqCount += dist[i]*i;
			// p-value: probability of >= matchingSequences matching sequences.
			// NOTE(review): if this sums to 0.0, -Math.log(seqPvalue) printed
			// below becomes Infinity — confirm this is acceptable output.
			double seqPvalue = 0.0;
			for (int i=matchingSequences; i<dist.length; ++i) {
				seqPvalue += dist[i];
			}
			Log.restartTimer("calculate sequence count distribution");
			
			// Compound Poisson p-value for the total occurrence count; lambda is
			// the expected number of clumps (occurrences per clump factored out).
			double lambda = expectation/expectedClumpSize;
			PoissonDistribution poissonDist = new PoissonDistribution(lambda);
			double totalOccPvalue = poissonDist.compoundPoissonPValue(clumpSizeDist, totalMatches);
			double totalOccLogPvalue;
			if (totalOccPvalue>0.0) {
				totalOccLogPvalue = Math.log(totalOccPvalue);
			} else {
				// double precision underflowed; redo the computation in log space
				totalOccLogPvalue = poissonDist.logCompoundPoissonPValue(clumpSizeDist, totalMatches);
			}
//			// if double precision is not sufficient, repeat calculation in 
//			// logarithmic domain
//			if (pvalue==0.0) {
//				Log.println(Log.Level.DEBUG, "Repeating calculation in logarithmic domain.");
//				logarithmic=true;
//				clumpSizeDist = mac.clumpSizeDistribution(16, patternLength, 1e-300);
//				if (onlyPValue) {
//					pvalue = poissonDist.logCompoundPoissonPValue(clumpSizeDist, max_matches);
//				} else {
//					dist = poissonDist.logCompoundPoissonDistribution(clumpSizeDist, max_matches);
//					pvalue = dist[max_matches];
//				}
//			}
			Log.stopTimer("p-value calculations");
			// double timePValues = Log.getLastPeriodCpu();

			Log.stopTimer("Total time for this motif");
			// double timeMotif = Log.getLastPeriodCpu();

			// assemble the result line in the format announced above
			StringBuilder sb = new StringBuilder();
			sb.append(String.format(">> %s >occurrence_count> %d %e %s %e ", forwardPattern, totalMatches, -totalOccLogPvalue, LogSpace.toString(totalOccLogPvalue), expectation));
			sb.append(String.format(">sequence_count> %d %e %e %e ", matchingSequences, -Math.log(seqPvalue), seqPvalue, expectedSeqCount));
			sb.append(String.format(">stats> %d %d %d %e %e ", states, statesMinimal, statesProduct, singleExpectation, expectedClumpSize));
			sb.append(String.format(">model_order> %d ", textModel.order()));
			// calculate some statistics if truth is known
			if (truth!=null) {
				// positives/negatives in sample
				double nP = 0;
				double nN = 0;
				double totalLength = 0;
				for (BitArray ba : truth) {
					totalLength+=ba.size();
					nP+=ba.numberOfOnes();
				}
				nN=totalLength-nP;
				// true/false positives: every position covered by some match
				// window is a predicted positive; posSet prevents counting a
				// position twice when match windows overlap
				double nTP = 0;
				double nFP = 0;
				for (int i=0; i<namedSequences.size(); ++i) {
					// keep track of already-seen positions
					Set<Integer> posSet = new HashSet<Integer>();
					for (CDFA.MatchPosition mp : cdfa.findMatchPositions(sequences.get(i))) {
						for (int j=mp.getPosition()-forwardPattern.length()+1; j<=mp.getPosition(); ++j) {
							if (posSet.contains(j)) continue;
							posSet.add(j);
							if (truth[i].get(j)) nTP+=1;
							else nFP+=1;
						}
					}
				}
				// NOTE(review): nP==0 or nN==0 (empty/full truth) yields NaN in
				// the ratios below — confirm the truth file always has both classes.
				double nTN = nN - nFP;
				double nFN = nP - nTP;
				double nSn = nTP/nP;
				double nSp = nTN/nN;
				double nPPV = nTP/(nTP + nFP);
				double nPC = nTP/(nTP + nFN + nFP);
				double nCC = (nTP*nTN - nFN*nFP)/Math.sqrt((nTP+nFN)*(nTN+nFP)*(nTP+nFP)*(nTN+nFN));
				sb.append(String.format(">>truth_stats>> %d %d %d %d %f %f %f %f %f ", (int)nTP, (int)nFP, (int)nFN, (int)nTN, nSn, nPPV, nSp, nPC, nCC));
			}

//			sb.append(String.format(">>poisson>> %e %e ", lambda, expectedClumpSize));
//			sb.append(String.format(">>runtimes>> %e %e %e ", timeClumpSizeDist, timeConvolution, timeMotif));
//			if (pValueTable) {
//				sb.append(">>p_value_table>> ");
//				for (double d : dist) sb.append(String.format("%e ", d));
//			}
			if (sb.length()>0) Log.println(Log.Level.STANDARD, sb.toString());
		}
		Log.stopTimer("Total time");
		double timeTotal = Log.getLastPeriodCpu();
		// NOTE(review): "%t" alone is not a valid java.util.Formatter conversion
		// (it needs a suffix such as %tT); presumably mosdi.util.Log.printf
		// implements its own format handling — verify against Log's source.
		Log.printf(Log.Level.STANDARD, ">>!>total_time>> %t %n", timeTotal);
		return 0;
	}

   /**
    * Reads a file that contains the positions of the true binding sites and
    * converts it into one BitArray per sequence, where a set bit marks a
    * position covered by a true site.
    *
    * Expected line format: {@code <seq_nr>,<pos>,<string>} where {@code <seq_nr>}
    * starts at 0 and {@code <pos>} is given relative to the end of the sequence
    * (i.e. -1 denotes the last nucleotide, cf. usage()). The annotated string is
    * verified against the actual sequence content.
    *
    * Terminates the JVM with exit code 1 on a missing file, an I/O failure, or
    * a malformed/mismatching line (consistent with the error handling in run()).
    *
    * @param filename  path of the truth file
    * @param alphabet  alphabet used to encode the verification string
    * @param sequences encoded sequences the truth file refers to
    * @return one BitArray per sequence with the true-site positions set
    */
   public static BitArray[] readTruthFile(String filename, Alphabet alphabet, List<int[]> sequences) {
   	BitArray[] truth = new BitArray[sequences.size()];
   	for (int i=0; i<sequences.size(); ++i) {
   		truth[i]=new BitArray(sequences.get(i).length);
   	}
   	BufferedReader br = null;
   	try {
   		br = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
   		while (true) {
   			String line = br.readLine();
   			if (line==null) break;
   			StringTokenizer st = new StringTokenizer(line, ",", false);
   			if (st.countTokens()<3) {
   				Log.errorln("Truth-file: invalid format");
   				System.exit(1);
   			}
   			int seqNr = Integer.parseInt(st.nextToken());
   			// <pos> is relative to the sequence end: -1 maps to the last index
   			int pos = sequences.get(seqNr).length+Integer.parseInt(st.nextToken());
   			int[] s = alphabet.buildIndexArray(st.nextToken().toUpperCase());
   			// sanity check: the annotated string must occur at the given position
   			if (!Arrays.equals(Arrays.copyOfRange(sequences.get(seqNr), pos, pos+s.length), s)) {
   				// fixed typo in error message ("mismath" -> "mismatch")
   				Log.errorln("Truth-file: sequence mismatch");
   				System.exit(1);
   			}
   			for (int i=pos; i<pos+s.length; ++i) truth[seqNr].set(i, true);
   		}
   	} catch (FileNotFoundException e) {
   		Log.errorln("Truth-file not found, sorry!");
   		System.exit(1);
   	} catch (IOException e) {
   		Log.errorln("I/O failure while reading truth-file, sorry!");
   		System.exit(1);
   	} finally {
   		// fix: the reader (and underlying file handle) was previously never
   		// closed; release it on all paths
   		if (br!=null) {
   			try {
   				br.close();
   			} catch (IOException e) {
   				// best effort: a failure to close after reading is not fatal
   			}
   		}
   	}
   	return truth;
   }
}
